/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file flash_attention_score_split_b_bn2gs1s2d_db_high_performance.h
 * \brief
 */
#ifndef FLASH_ATTENTION_SCORE_SPLIT_B_B_N2_G_S1_S2_DB_PERFORMANCE_H
#define FLASH_ATTENTION_SCORE_SPLIT_B_B_N2_G_S1_S2_DB_PERFORMANCE_H
#include "flash_attention_score_common.h"
#include "kernel_data_copy_transpose.h"
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"

struct SplitBExtraInfo {
    int64_t boIdx;
    int64_t biN2GoIdx;
    int64_t s1oIdx;
    int64_t taskId;
    int64_t vecS1BaseSize; // S1 base block; keep vecS1BaseSize * S2 <= 32K
    int64_t vecS1TailSize; // S1 tail block
    int64_t s2AlignSize; // S2 base block size
    int64_t s2AlignBlockSize; // S2 base block size rounded up to 32 alignment
    int64_t s1Vec2BaseSize;
    int64_t s1Vec2BaseTailSize;
    int64_t s1Vec2OuterSize;
    int64_t qCoreOffset;
    int64_t softmaxCopyOutLimit;
    int64_t softmaxCopyOutSize;
    int64_t softmaxOutOffset = 0;
};
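
// One SplitBExtraInfo instance describes a single scheduling task. Process() keeps three
// of them in flight (indexed by taskId % 3) so that bmm1, the stage-1 vector pass, bmm2
// and the stage-2 vector pass of consecutive tasks can overlap.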

// INPUT_T - data type of the input
// T       - data type used during computation
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T = INPUT_T,  \
    bool isBasicBlock = false, LayoutMode layout = LayoutMode::BNGS1S2>
class FlashAttentionScoreSplitBBn2gs1s2DBHighPerf
{
public:
  __aicore__ inline FlashAttentionScoreSplitBBn2gs1s2DBHighPerf(){};
  __aicore__ inline void Init(__gm__ uint8_t* query, __gm__ uint8_t* key, __gm__ uint8_t* value, __gm__ uint8_t* pse,
                              __gm__ uint8_t* dropMask, __gm__ uint8_t* paddingMask, __gm__ uint8_t* prefix,
                              __gm__ uint8_t* attenMask, __gm__ uint8_t* softmaxMax, __gm__ uint8_t* softmaxSum,
                              __gm__ uint8_t* softmaxOut, __gm__ uint8_t* attentionOut, __gm__ uint8_t* workspace,
                              const FlashAttentionScoreGeneralTilingData* __restrict tiling, TPipe* tPipe);
  __aicore__ inline void Process();

  // define batchmatmul
  using a1Type = MatmulType<TPosition::GM, CubeFormat::ND, INPUT_T, false, layout>;
  using b1Type = MatmulType<TPosition::GM, CubeFormat::ND, INPUT_T, true, layout>;
  using bias1Type = MatmulType<TPosition::GM, CubeFormat::ND, float, false, LayoutMode::BNGS1S2>;
  using c1Type = MatmulType<TPosition::GM, CubeFormat::ND, T, false, LayoutMode::BNGS1S2>;
  matmul::Matmul<a1Type, b1Type, c1Type, bias1Type> bmm1;

  using a2Type = MatmulType<TPosition::GM, CubeFormat::ND, INPUT_T, false, LayoutMode::BNGS1S2>;
  using b2Type = MatmulType<TPosition::GM, CubeFormat::ND, INPUT_T, false, layout>;
  using bias2Type = MatmulType<TPosition::GM, CubeFormat::ND, float, false, layout>;
  using c2Type = MatmulType<TPosition::GM, CubeFormat::ND, T, false, LayoutMode::BNGS1S2>;
  matmul::Matmul<a2Type, b2Type, c2Type, bias2Type> bmm2;
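
  // bmm1 computes the attention scores query x key^T (key transposed via b1Type); bmm2
  // multiplies the stage-1 result (the softmax output cast back to INPUT_T) with value.
  // Both read their operands from GM and write T-typed results to the GM workspace.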

protected:
  __aicore__ inline void InitInput(__gm__ uint8_t* query, __gm__ uint8_t* key, __gm__ uint8_t* value,
                                  __gm__ uint8_t* pse, __gm__ uint8_t* dropMask, __gm__ uint8_t* paddingMask,
                                  __gm__ uint8_t* prefix,
                                  __gm__ uint8_t* attenMask, __gm__ uint8_t* softmaxMax, __gm__ uint8_t* softmaxSum,
                                  __gm__ uint8_t* softmaxOut, __gm__ uint8_t* attentionOut, __gm__ uint8_t* workspace,
                                  const FlashAttentionScoreGeneralTilingData* __restrict tiling, TPipe* tPipe);
  __aicore__ inline void WaitBmm1Result();
  __aicore__ inline void WaitBmm2Result();
  __aicore__ inline void IterateBmm1(SplitBExtraInfo& extraInfo);
  __aicore__ inline void IterateBmm2(SplitBExtraInfo& extraInfo);
  __aicore__ inline void BoolCopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor, int64_t offset,
                                    SplitBExtraInfo& extraInfo, int32_t s2Size, int64_t totalS2Size);
  __aicore__ inline int64_t ComputeQCoreOffset(SplitBExtraInfo& extraInfo);
  __aicore__ inline int64_t ComputeKVCoreOffset(SplitBExtraInfo& extraInfo);
  __aicore__ inline void SetExtraInfo(SplitBExtraInfo& extraInfo, int64_t taskId, int64_t multiCoreInnerIdx);
  __aicore__ inline void SetTiling(const FlashAttentionScoreGeneralTilingData* __restrict tilingData);
  __aicore__ inline void InitBuffer();
  __aicore__ inline void CalBatchSize();
  __aicore__ inline void ComputeConstexpr();
  __aicore__ inline void RefreshConstexpr();
  __aicore__ inline void ComputeAxisIdx(int64_t multiCoreInnerIdx);
  __aicore__ inline void ProcessVec1(SplitBExtraInfo& extraInfo);
  __aicore__ inline void ProcessVec2(SplitBExtraInfo &extraInfo);
  __aicore__ inline void CopyInAndCastPse(LocalTensor<T> &castTensor, SplitBExtraInfo &extraInfo, int64_t loopIdx);
  __aicore__ inline int64_t ComputePseOffset(SplitBExtraInfo &extraInfo);
  __aicore__ inline void CopyInAttenMask(SplitBExtraInfo& extraInfo, int64_t maskOffset);
  __aicore__ inline int64_t ComputeAttenMaskOffset(SplitBExtraInfo& extraInfo);
  __aicore__ inline int64_t ComputeOffsetForNoCompress(SplitBExtraInfo &extraInfo);
  __aicore__ inline void GetBmm1Result(SplitBExtraInfo& extraInfo, LocalTensor<T>& bmm1ResUb, int64_t loopIdx);
  __aicore__ inline void ComputePse(LocalTensor<T> &pseUb, LocalTensor<T> &bmm1ResUb, SplitBExtraInfo &extraInfo);
  __aicore__ inline void ComputeAttenMask(SplitBExtraInfo& extraInfo, LocalTensor<T>& bmm1ResUb, const uint8_t maskType);
  __aicore__ inline void CopyInDropMask(SplitBExtraInfo& extraInfo);
  __aicore__ inline int64_t ComputeDropOffset(SplitBExtraInfo& extraInfo);
  __aicore__ inline void SoftMaxCompute(SplitBExtraInfo &extraInfo, LocalTensor<T> &srcTensor, int64_t loopIdx);
  __aicore__ inline void ComputeDropMask(SplitBExtraInfo &extraInfo, LocalTensor<T> &bmm1ResUb, event_t eventIdMte2ToV);
  __aicore__ inline void DataCopyIn(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
                                    int32_t s1Size, int32_t alignS2Size);
  __aicore__ inline void Bmm2ResultDiv(SplitBExtraInfo &extraInfo, int64_t vec2S1Idx, LocalTensor<T> &bmm2Res,
                                       LocalTensor<T> &sumTensor, int64_t s1Vec2BaseSize);
  __aicore__ inline void Bmm2DataCopyOut(SplitBExtraInfo &extraInfo, int64_t vec2S1Idx, LocalTensor<T> &bmm2Res,
                                         LocalTensor<INPUT_T> &attentionOut);

    // dataCopyTranspose parameters
    CopyTransposeTiling dataCopyTiling;

    uint32_t s1BaseSize;
    uint32_t s1BaseTailSize;
    uint32_t s2BaseSize;
    uint32_t s1CBaseSize;
    uint32_t s2CBaseSize;
    uint32_t dSize;
    uint32_t dBaseSize;
    uint32_t s1Size;
    uint32_t s2Size;

    // functions used for sparse mode
    __aicore__ inline void GetS1LoopRange(int64_t &multiCoreInnerOffset, int64_t &multiCoreInnerLimit);
    __aicore__ inline void GetS2LoopRange();

    // parameters used for sparse mode
    int64_t s2StartIdx;
    int64_t s2EndIdx;
    int64_t s2LoopStart;
    int64_t tensorABatchSize;
    int64_t tensorBBatchSize;

    // Vec compute length along the s1 direction
    uint32_t s1VecSize;
    // tail block along the s2 direction, not including the N:1 ratio
    int32_t s2TailSize;

    // per-core offsets of the k and v matrices
    int64_t kCoreOffset;
    int64_t qCoreOffset;

    // resource allocation
    TBuf<> maskTBufPing;            // 11K
    TBuf<> maskTBufPong;            // 11K
    TBuf<> pseTBuf;                 // 16K
    TBuf<> stage1PingBuf;           // 32K
    TBuf<> stage1PongBuf;           // 32K
    TBuf<> softmaxSumPingBuf;       // 8K
    TBuf<> softmaxSumPongBuf;       // 8K
    TBuf<> softmaxMaxPingBuf;       // 8K
    TBuf<> softmaxMaxPongBuf;       // 8K
    TBuf<> softmaxExpBuf;           // 8K
    TBuf<> commonTBuf;              // 32k
    TBuf<> vecOut;                  // 16k

    GlobalTensor<T> mm1ResPing;
    GlobalTensor<T> mm1ResPong;
    GlobalTensor<T> mm2ResPing;
    GlobalTensor<T> mm2ResPong;
    GlobalTensor<INPUT_T> stage1ResPing;
    GlobalTensor<INPUT_T> stage1ResPong;
    GlobalTensor<uint8_t> dropoutWorkspaceGm;

    // products of the axes
    int64_t gS1o;
    int64_t n2GS1o;
    int64_t s1D;
    int64_t gS1D;
    int64_t n2GS1D;
    int64_t s2D;
    int64_t n2S2D;
    int64_t s1S2;
    int64_t gS1S2;
    int64_t n2GS1S2;
    int64_t gS1;
    int64_t n2GS1;
    int64_t gD;
    int64_t n2D;
    int64_t bN2D;
    int64_t n2G;
    int64_t n2GD;
    int64_t bN2G;
    int64_t bN2GD;
    int64_t gS2;
    int64_t n2GS2;
    int64_t biN2G;
    int64_t biN2GoIdx;
    int64_t biN2GS1D;
    int64_t biN2GS1;
    int64_t biN2GD;
    int64_t biN2S2D;
    int64_t biN2D;
    int64_t biN2GS1S2;
    int64_t biN2GS2;
    int64_t bBaseSize;
    int64_t s1OuterSize;
    int64_t s1Vec2BaseSize;
    int64_t s1Vec2BaseTailSize;
    int64_t s1Vec2OuterSize;

    int64_t softmaxBufSize = 256;
    uint32_t negativeIntScalar = NEGATIVE_MIN_VAULE_FP32;
    T negativeFloatScalar;
    T positiveFloatScalar;

    // prefixNList
    uint32_t prefixNList[PREFIX_N_MAX_B];

    uint8_t attenMaskCompressMode;

    int32_t blockIdx;
    const FlashAttentionScoreGeneralTilingData *__restrict tilingData;
    int64_t boIdx;
    int64_t n2oIdx;
    int64_t goIdx;
    int64_t s1oIdx;

    TPipe *pipe;

    GlobalTensor<INPUT_T> queryGm;
    GlobalTensor<INPUT_T> keyGm;
    GlobalTensor<INPUT_T> pseGm;
    GlobalTensor<INPUT_T> valueGm;
    GlobalTensor<INPUT_T> attentionOutGm;
    GlobalTensor<float> softmaxMaxGm;
    GlobalTensor<float> softmaxSumGm;
    GlobalTensor<uint8_t> dropMaskGm;
    GlobalTensor<uint8_t> attenMaskGmInt;

    bool dropMaskUnAligned;

    int64_t pseBSize;
    int64_t attenMaskOffsetPre;
};

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::Init(
    __gm__ uint8_t* query, __gm__ uint8_t* key, __gm__ uint8_t* value, __gm__ uint8_t* pse, __gm__ uint8_t* dropMask,
    __gm__ uint8_t* paddingMask, __gm__ uint8_t* prefix, __gm__ uint8_t* attenMask, __gm__ uint8_t* softmaxMax,
    __gm__ uint8_t* softmaxSum, __gm__ uint8_t* softmaxOut, __gm__ uint8_t* attentionOut, __gm__ uint8_t* workspace,
    const FlashAttentionScoreGeneralTilingData* __restrict tiling, TPipe* tPipe)
{
    this->InitInput(query, key, value, pse, dropMask, paddingMask, prefix, attenMask, softmaxMax, softmaxSum,
                    softmaxOut, attentionOut, workspace, tiling, tPipe);  // set up GM buffers

    this->ComputeConstexpr();
    this->CalBatchSize();
    this->InitBuffer();
    LocalTensor<T> apiTmpBuffer = this->commonTBuf.template Get<T>();
    DropOutBitModeInit(apiTmpBuffer);
}
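
// A minimal usage sketch (host/launch glue omitted, template arguments illustrative):
// the kernel is constructed on the AI core, initialized once with the GM addresses,
// tiling data and a TPipe, and then driven by a single Process() call, e.g.
//     FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten,
//                                                 hasDrop, half, float> op;
//     op.Init(query, key, value, pse, dropMask, paddingMask, prefix, attenMask, softmaxMax,
//             softmaxSum, softmaxOut, attentionOut, workspace, tilingData, tPipe);
//     op.Process();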

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::SetTiling(const FlashAttentionScoreGeneralTilingData *__restrict tilingData)
{
    // copy base params
    this->tilingData = tilingData;
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::InitInput(
    __gm__ uint8_t* query, __gm__ uint8_t* key, __gm__ uint8_t* value, __gm__ uint8_t* pse, __gm__ uint8_t* dropMask,
    __gm__ uint8_t* paddingMask, __gm__ uint8_t* prefix, __gm__ uint8_t* attenMask, __gm__ uint8_t* softmaxMax,
    __gm__ uint8_t* softmaxSum, __gm__ uint8_t* softmaxOut, __gm__ uint8_t* attentionOut, __gm__ uint8_t* workspace,
    const FlashAttentionScoreGeneralTilingData* __restrict tiling, TPipe* tPipe)
{
    this->blockIdx = GetBlockIdx();
    this->pipe = tPipe;
    this->SetTiling(tiling);

    this->queryGm.SetGlobalBuffer((__gm__ INPUT_T*)query);
    this->keyGm.SetGlobalBuffer((__gm__ INPUT_T*)key);
    this->valueGm.SetGlobalBuffer((__gm__ INPUT_T*)value);
    this->pseGm.SetGlobalBuffer((__gm__ INPUT_T*)pse);
    this->dropMaskGm.SetGlobalBuffer((__gm__ uint8_t*)dropMask);
    this->attenMaskGmInt.SetGlobalBuffer((__gm__ uint8_t*)attenMask);
    this->softmaxMaxGm.SetGlobalBuffer((__gm__ float*)softmaxMax);
    this->softmaxSumGm.SetGlobalBuffer((__gm__ float*)softmaxSum);
    this->attentionOutGm.SetGlobalBuffer((__gm__ INPUT_T*)attentionOut);

    int64_t mm1ResultSize = this->tilingData->coreParams.bBaseSize * this->tilingData->inputParams.n2Size *
                            this->tilingData->inputParams.gSize * this->tilingData->inputParams.s1Size *
                            this->tilingData->coreParams.s2BaseSize;
    int64_t mm2ResultSize = this->tilingData->coreParams.bBaseSize * this->tilingData->inputParams.n2Size *
                            this->tilingData->inputParams.gSize * this->tilingData->inputParams.s1Size *
                            this->tilingData->coreParams.dBaseSize;

    // align to 512 bytes
    int64_t mm1Offset = CeilDiv(mm1ResultSize, 128) * 128 * sizeof(T);
    int64_t mm2Offset = CeilDiv(mm2ResultSize, 128) * 128 * sizeof(T);

    int64_t mm1TotalOffset =  mm1Offset * 2;
    int64_t mm2TotalOffset =  mm2Offset * 2;
    uint64_t dropmaskWorkSpaceLen = this->tilingData->dropmaskParams.shapeTotalSize;
    dropmaskWorkSpaceLen = CeilDiv(dropmaskWorkSpaceLen, 512) * 512;
    uint64_t mm1ResPingAddr = (dropmaskWorkSpaceLen +
                               this->blockIdx * (mm1TotalOffset + mm2TotalOffset)) / sizeof(T);
    uint64_t mm1ResPongAddr = mm1ResPingAddr + mm1Offset / sizeof(T);
    uint64_t mm2ResPingAddr = mm1ResPongAddr + mm1Offset / sizeof(T);
    uint64_t mm2ResPongAddr = mm2ResPingAddr + mm2Offset / sizeof(T);
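
    // Per-core workspace layout behind the shared dropout-mask region at offset 0:
    // [mm1 ping | mm1 pong | mm2 ping | mm2 pong], each padded to the alignment above.
    // The stage1 ping/pong buffers below are addressed in INPUT_T elements at twice the
    // mm1 element offsets, i.e. they reuse the same byte ranges as the mm1 buffers when
    // sizeof(T) == 2 * sizeof(INPUT_T) (fp32 workspace with fp16/bf16 input).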

    // bmm1 result, occupies mm1Offset bytes
    this->mm1ResPing.SetGlobalBuffer((__gm__ T*)workspace + mm1ResPingAddr);
    this->mm1ResPong.SetGlobalBuffer((__gm__ T*)workspace + mm1ResPongAddr);

    // stage1 result, occupies mm1Offset bytes
    this->stage1ResPing.SetGlobalBuffer((__gm__ INPUT_T*)workspace + (mm1ResPingAddr * 2));
    this->stage1ResPong.SetGlobalBuffer((__gm__ INPUT_T*)workspace + (mm1ResPongAddr * 2));

    this->mm2ResPing.SetGlobalBuffer((__gm__ T*)workspace + mm2ResPingAddr);
    this->mm2ResPong.SetGlobalBuffer((__gm__ T*)workspace + mm2ResPongAddr);

    // dropout workspace
    dropoutWorkspaceGm.SetGlobalBuffer((__gm__ uint8_t*)workspace);

    this->dropMaskUnAligned = this->tilingData->inputParams.needDropMaskOp == 1;
    if constexpr (IsSameType<T, half>::value) {
        this->negativeIntScalar = NEGATIVE_MIN_VAULE_FP16;
    }
    GetExtremeValue(this->negativeFloatScalar, this->positiveFloatScalar);

    if (this->attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::PREFIX_MODE)) {
        uint16_t minN = (this->tilingData->inputParams.bSize > PREFIX_N_MAX_B) ?
                        PREFIX_N_MAX_B : this->tilingData->inputParams.bSize;
        for (uint16_t i = 0; i < minN; ++i) {
            this->prefixNList[i] = ((__gm__ uint64_t *)prefix)[i];
        }
    }
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::ComputeConstexpr()
{
    // compute products of the axes
    this->s1S2 = this->tilingData->inputParams.s1Size * this->tilingData->inputParams.s2Size;
    this->gS1D = this->tilingData->inputParams.gSize * this->tilingData->inputParams.s1Size *
    this->tilingData->inputParams.dSize;
    this->dBaseSize = this->tilingData->coreParams.dBaseSize;
    this->s1Size = this->tilingData->inputParams.s1Size;
    this->s2Size = this->tilingData->inputParams.s2Size;
    this->dSize = this->tilingData->inputParams.dSize;
    this->bBaseSize = this->tilingData->coreParams.bBaseSize;
    this->gS1S2 = this->tilingData->inputParams.gSize * this->s1S2;
    this->gD = this->tilingData->inputParams.gSize * this->tilingData->inputParams.dSize;
    this->gS1 = this->tilingData->inputParams.gSize * this->tilingData->inputParams.s1Size;
    this->gS2 = this->tilingData->inputParams.gSize * this->tilingData->inputParams.s2Size;
    this->n2D = this->tilingData->inputParams.n2Size * this->tilingData->inputParams.dSize;
    this->n2G = this->tilingData->inputParams.n2Size * this->tilingData->inputParams.gSize;
    this->n2GD = this->tilingData->inputParams.n2Size * this->gD;
    this->bN2G = this->tilingData->inputParams.bSize * this->tilingData->inputParams.n2Size *
                 this->tilingData->inputParams.gSize;
    this->bN2GD = this->bN2G * this->dSize;
    this->n2GS1D = this->tilingData->inputParams.n2Size * this->gS1D;
    this->n2S2D = this->tilingData->inputParams.n2Size * this->tilingData->inputParams.s2Size *
                  this->tilingData->inputParams.dSize;
    this->n2GS1S2 = this->tilingData->inputParams.n2Size * this->gS1S2;
    this->n2GS1 = this->tilingData->inputParams.n2Size * this->gS1;
    this->n2GS2 = this->tilingData->inputParams.n2Size * this->gS2;
    this->s1BaseSize = this->tilingData->coreParams.s1BaseSize;
    this->s2BaseSize = this->tilingData->coreParams.s2BaseSize;  // S2 aligned to 16
    this->pseBSize = this->tilingData->inputParams.pseBSize;
    this->attenMaskCompressMode = this->tilingData->inputParams.attenMaskCompressMode;
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::RefreshConstexpr()
{
    if (this->blockIdx == this->tilingData->multiCoreParams.coreNum - 1 &&
        this->boIdx == this->tilingData->coreParams.bOuterSize - 1) {
        this->bBaseSize = this->tilingData->coreParams.bBaseTailSize;
    }
    this->biN2G = this->bBaseSize * this->tilingData->inputParams.n2Size * this->tilingData->inputParams.gSize;
    this->biN2GS1D = this->bBaseSize * this->n2GS1D;
    this->biN2GD = this->bBaseSize * this->n2GD;
    this->biN2S2D = this->bBaseSize * this->n2S2D;
    this->biN2D = this->bBaseSize * this->n2D;
    this->biN2GS1S2 = this->bBaseSize * this->n2GS1S2;
    this->biN2GS2 = this->bBaseSize * this->n2GS2;
    this->biN2GS1 = this->bBaseSize * this->n2GS1;
    this->s1OuterSize = this->tilingData->coreParams.s1OuterSize;
    this->s1BaseSize = this->tilingData->coreParams.s1BaseSize;
    this->s1BaseTailSize = this->tilingData->coreParams.s1BaseTailSize;
    this->s1Vec2BaseSize = this->tilingData->coreParams.s1Vec2BaseSize;
    this->s1Vec2BaseTailSize = this->tilingData->coreParams.s1Vec2BaseTailSize;
    this->s1Vec2OuterSize = this->tilingData->coreParams.s1Vec2OuterSize;
    this->tensorABatchSize = this->bBaseSize * this->tilingData->inputParams.n2Size *
                             this->tilingData->inputParams.gSize;
    this->tensorBBatchSize = this->bBaseSize * this->tilingData->inputParams.n2Size;
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::CalBatchSize()
{
    this->tensorABatchSize = this->bBaseSize * this->tilingData->inputParams.n2Size *
                             this->tilingData->inputParams.gSize;
    this->tensorBBatchSize = this->bBaseSize * this->tilingData->inputParams.n2Size;
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::InitBuffer()
{
    this->pipe->InitBuffer(this->maskTBufPing, 11 * 1024);                                       // 11K for the atten mask
    this->pipe->InitBuffer(this->maskTBufPong, 11 * 1024);                                       // 11K for the dropout mask
    this->pipe->InitBuffer(this->pseTBuf, 8 * 1024 * sizeof(INPUT_T));                          // pse 16K

    this->pipe->InitBuffer(this->stage1PingBuf, 8 * 1024 * sizeof(T));                          // t.a 32k
    this->pipe->InitBuffer(this->stage1PongBuf, 8 * 1024 * sizeof(T));                          // i.a 32k
    this->pipe->InitBuffer(this->commonTBuf, 64 * 128 * sizeof(T));                             // t.b 32k

    this->pipe->InitBuffer(this->softmaxSumPingBuf, this->softmaxBufSize * blockBytes);   // 8k max
    this->pipe->InitBuffer(this->softmaxSumPongBuf, this->softmaxBufSize * blockBytes);   // 8k max
    this->pipe->InitBuffer(this->softmaxMaxPingBuf, this->softmaxBufSize * blockBytes);   // 8k max
    this->pipe->InitBuffer(this->softmaxMaxPongBuf, this->softmaxBufSize * blockBytes);   // 8k max
    this->pipe->InitBuffer(this->softmaxExpBuf, this->softmaxBufSize * blockBytes);       // 8k max

    // holds the vector output, i.e. the result after the cast to fp16
    this->pipe->InitBuffer(this->vecOut, 8 * 1024 * sizeof(INPUT_T));
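
    // Rough UB budget, assuming 2-byte INPUT_T, 4-byte T and blockBytes == 32:
    // 11K + 11K + 16K + 32K + 32K + 32K + 5 * 8K + 16K = 190K in total.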
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::Process()
{
    int64_t multiCoreInnerOffset = this->blockIdx * this->tilingData->multiCoreParams.splitFactorSize;
    int64_t multiCoreInnerLimit = multiCoreInnerOffset + this->tilingData->multiCoreParams.splitFactorSize;
    if (this->tilingData->multiCoreParams.totalSize < multiCoreInnerLimit) {
        multiCoreInnerLimit = this->tilingData->multiCoreParams.totalSize;
    }

    SplitBExtraInfo extraInfo[3];
    int64_t taskId = 0;
    event_t eventIdMte3ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
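
    // Software pipeline across tasks: in iteration taskId, bmm1 of the current task is
    // issued, the stage-1 vector pass and bmm2 of task taskId - 1 run on its finished
    // bmm1 result, and the stage-2 vector pass of task taskId - 2 consumes the finished
    // bmm2 result. The code after the loop drains the last tasks still in flight.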
    for (int64_t multiCoreInnerIdx = multiCoreInnerOffset; multiCoreInnerIdx < multiCoreInnerLimit;
         multiCoreInnerIdx++) {
        this->boIdx = multiCoreInnerIdx;
        RefreshConstexpr();
        if (taskId > 0) {
            WaitBmm1Result();
        }
        this->SetExtraInfo(extraInfo[taskId % 3], taskId, multiCoreInnerIdx);
        this->IterateBmm1(extraInfo[taskId % 3]);
        if (taskId > 0) {
            this->ProcessVec1(extraInfo[(taskId + 2) % 3]);
            SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
        }
        if (taskId > 1) {
            WaitBmm2Result();
        }
        if (taskId > 0) {
            WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
            this->IterateBmm2(extraInfo[(taskId + 2) % 3]);
        }
        if (taskId > 1) {
            this->ProcessVec2(extraInfo[(taskId + 1) % 3]);
        }
        taskId++;
    }
    if (taskId >= 1) {
        WaitBmm1Result();
        this->ProcessVec1(extraInfo[(taskId + 2) % 3]);
        SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
        if (taskId > 1) {
            WaitBmm2Result();
        }
        WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
        this->IterateBmm2(extraInfo[(taskId + 2) % 3]);
        if (taskId > 1) {
            this->ProcessVec2(extraInfo[(taskId + 1) % 3]);
        }
    }
    taskId++;
    if (taskId >= 2) {
        WaitBmm2Result();
        this->ProcessVec2(extraInfo[(taskId + 1) % 3]);
    }
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::WaitBmm1Result()
{
    bmm1.WaitIterateBatch();
    bmm1.End();
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::SetExtraInfo(
    SplitBExtraInfo& extraInfo, int64_t taskId, int64_t multiCoreInnerIdx)
{
    extraInfo.boIdx = multiCoreInnerIdx;
    extraInfo.taskId = taskId;
    extraInfo.vecS1BaseSize = this->s1BaseSize;
    extraInfo.vecS1TailSize = this->s1BaseTailSize;
    extraInfo.s2AlignSize = this->s2BaseSize;
    extraInfo.s2AlignBlockSize = CeilDiv(this->s2BaseSize, blockBytes) * blockBytes;
    extraInfo.s1Vec2BaseSize = this->s1Vec2BaseSize;
    extraInfo.s1Vec2BaseTailSize = this->s1Vec2BaseTailSize;
    extraInfo.s1Vec2OuterSize = this->s1Vec2OuterSize;
    extraInfo.softmaxCopyOutLimit = this->softmaxBufSize / extraInfo.vecS1BaseSize;
    extraInfo.softmaxCopyOutSize = Min(this->s1OuterSize, extraInfo.softmaxCopyOutLimit) * this->s1BaseSize;
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::IterateBmm1(SplitBExtraInfo& extraInfo)
{
    int64_t qCoreOffset = this->ComputeQCoreOffset(extraInfo);
    int64_t kvCoreOffset = this->ComputeKVCoreOffset(extraInfo);
    extraInfo.qCoreOffset = qCoreOffset;
    this->bmm1.SetTensorA(this->queryGm[qCoreOffset]);
    this->bmm1.SetTensorB(this->keyGm[kvCoreOffset], true);
    if (extraInfo.taskId % 2 == 0) {
        this->bmm1.template IterateBatch<false, true>(this->mm1ResPing, this->tensorABatchSize, this->tensorBBatchSize,
                                                      false);
    } else {
        this->bmm1.template IterateBatch<false, true>(this->mm1ResPong, this->tensorABatchSize, this->tensorBBatchSize,
                                                      false);
    }
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::WaitBmm2Result()
{
    bmm2.WaitIterateBatch();
    bmm2.End();
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::IterateBmm2(SplitBExtraInfo& extraInfo)
{
    int64_t qCoreOffset = this->ComputeQCoreOffset(extraInfo);
    int64_t kvCoreOffset = this->ComputeKVCoreOffset(extraInfo);

    if (extraInfo.taskId % 2 == 0) {
        this->bmm2.SetTensorA(this->stage1ResPing);
    } else {
        this->bmm2.SetTensorA(this->stage1ResPong);
    }
    this->bmm2.SetTensorB(this->valueGm[kvCoreOffset]);
    if (extraInfo.taskId % 2 == 0) {
        this->bmm2.template IterateBatch<false, true>(this->mm2ResPing, this->tensorABatchSize, this->tensorBBatchSize,
                                                      false);
    } else {
        this->bmm2.template IterateBatch<false, true>(this->mm2ResPong, this->tensorABatchSize, this->tensorBBatchSize,
                                                      false);
    }
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline int64_t FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::ComputeQCoreOffset(SplitBExtraInfo& extraInfo)
{
    // compute the offset on GM
    int64_t qBOffset = 0;

    if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BSH) {
        // BSH/BNGS1S2
        qBOffset = extraInfo.boIdx * this->biN2GS1D;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_SBH) {
        // SBH/SBNGD
        qBOffset = extraInfo.boIdx * this->biN2GD;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BNSD) {
        // BNGSD
        qBOffset = extraInfo.boIdx * this->biN2GS1D;
    }

    return qBOffset;
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline int64_t FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::ComputeKVCoreOffset(SplitBExtraInfo& extraInfo)
{
    // compute the offset on GM
    int64_t kvBOffset = 0;

    if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BSH) {
        // BSH/BSND
        kvBOffset = extraInfo.boIdx * this->biN2S2D;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_SBH) {
        // SBH/SBND
        kvBOffset = extraInfo.boIdx * this->biN2D;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BNSD) {
        // BNSD
        kvBOffset = extraInfo.boIdx * this->biN2S2D;
    }

    return kvBOffset;
}
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::GetBmm1Result(SplitBExtraInfo &extraInfo, LocalTensor<T> &bmm1ResUb, int64_t loopIdx)
{
    int32_t dtypeSize = sizeof(T);
    int32_t s2Align8 = (this->s2Size + 7) / 8 * 8;
    if (extraInfo.s2AlignSize == this->s2Size) {
        // 16-aligned case: a plain DataCopy gives better performance
        if (extraInfo.taskId % 2 == 0) {
            DataCopy(bmm1ResUb, this->mm1ResPing[extraInfo.biN2GoIdx * this->s1Size * this->s2Size + loopIdx *
                     this->s1BaseSize * this->s2Size], extraInfo.vecS1BaseSize * this->s2Size);
        } else {
            DataCopy(bmm1ResUb, this->mm1ResPong[extraInfo.biN2GoIdx * this->s1Size * this->s2Size + loopIdx *
                     this->s1BaseSize * this->s2Size], extraInfo.vecS1BaseSize * this->s2Size);
        }
    } else {
        DataCopyParams dataCopyParams;
        DataCopyPadParams dataCopyPadParams;
        dataCopyParams.blockCount = extraInfo.vecS1BaseSize;
        dataCopyParams.dstStride = 0;
        dataCopyParams.srcStride = 0;
        dataCopyParams.blockLen = this->s2Size * dtypeSize;
        dataCopyPadParams.rightPadding = extraInfo.s2AlignSize - this->s2Size;
        dataCopyPadParams.paddingValue = 0;
        if (dataCopyPadParams.rightPadding > blockSize) {
            // 8-aligned case: the vector core needs 16 alignment internally, so pad zeros manually during the data copy
            dataCopyPadParams.rightPadding -= blockSize;
            dataCopyParams.dstStride = 1;
            Duplicate<T>(bmm1ResUb[s2Align8], 0, blockSize, extraInfo.vecS1BaseSize, 0, extraInfo.s2AlignSize *
                         sizeof(T) / blockBytes);
        }
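        // Worked example, assuming 4-byte T and blockSize == 8: s2Size == 20 gives
        // s2AlignSize == 32 and an initial rightPadding of 12 > 8, so DataCopyPad pads
        // only 4 elements (up to the 8-aligned 24) and skips one block per row via
        // dstStride == 1, while Duplicate zero-fills the remaining elements [24, 32).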
        if (extraInfo.taskId % 2 == 0) {
            DataCopyPad(bmm1ResUb, this->mm1ResPing[extraInfo.biN2GoIdx * this->s1Size * this->s2Size + loopIdx *
                        this->s1BaseSize * this->s2Size], dataCopyParams, dataCopyPadParams);
        } else {
            DataCopyPad(bmm1ResUb, this->mm1ResPong[extraInfo.biN2GoIdx * this->s1Size * this->s2Size + loopIdx *
                        this->s1BaseSize * this->s2Size], dataCopyParams, dataCopyPadParams);
        }
    }
    uint32_t bmm1ResUbShape[] = {static_cast<uint32_t>(extraInfo.vecS1BaseSize),
                                 static_cast<uint32_t>(extraInfo.s2AlignSize)};
    bmm1ResUb.SetShapeInfo(ShapeInfo(2, bmm1ResUbShape, DataFormat::ND));
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::ComputePse(LocalTensor<T> &pseUb, LocalTensor<T> &bmm1ResUb, SplitBExtraInfo &extraInfo)
{
    if constexpr (hasPse == true) {
        int32_t computeSize = this->tilingData->inputParams.pseShapeType == pseS1S2 ? extraInfo.vecS1BaseSize *
                              extraInfo.s2AlignSize : extraInfo.s2AlignSize;
        pipe_barrier(PIPE_V);
        PseBroadcastAdd(extraInfo.vecS1BaseSize, extraInfo.s2AlignSize, computeSize, pseUb, bmm1ResUb,
                        this->tilingData->inputParams.pseShapeType);
        return;
    }
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::ProcessVec1(SplitBExtraInfo &extraInfo)
{
    LocalTensor<T> stage1PingTensor = this->stage1PingBuf.template Get<T>(); // t.a 32k
    LocalTensor<T> stage1PongTensor = this->stage1PongBuf.template Get<T>(); // i.a 32k
    event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>());
    event_t eventIdVToMte2A = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::V_MTE2>());
    event_t eventIdVToMte2B = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::V_MTE2>());
    event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
    event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
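
    // Stage-1 vector pass, per (b_i, n2, g) batch and per s1 base block: load pse and the
    // bmm1 result, add pse, scale by scaleValue, apply the attention mask (with an extra
    // mask pass in band/prefix compress modes), then softmax and dropout; when T differs
    // from INPUT_T the result is cast back to INPUT_T and written to the stage-1 workspace
    // as bmm2's left operand. The V_MTE2 ping/pong events double-buffer loads and compute.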
    uint32_t loopIdxNew = 0;
    for (uint32_t biN2GoIdx = 0; biN2GoIdx < this->biN2G; biN2GoIdx++) {
        extraInfo.biN2GoIdx = biN2GoIdx;
        extraInfo.softmaxOutOffset = 0;
        for (uint32_t loopIdx = 0; loopIdx < this->s1OuterSize; loopIdx++, loopIdxNew++) {
            extraInfo.vecS1BaseSize = this->s1BaseSize;
            if (loopIdx == this->s1OuterSize - 1) {
                extraInfo.vecS1BaseSize = extraInfo.vecS1TailSize;
            }
            extraInfo.s1oIdx = loopIdx;

            if (loopIdxNew > 0) {
                WaitFlag<HardEvent::V_MTE2>(eventIdVToMte2A);
            }
            this->CopyInAndCastPse(stage1PingTensor, extraInfo, loopIdx);
            this->GetBmm1Result(extraInfo, stage1PongTensor, loopIdx);
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            this->ComputePse(stage1PingTensor, stage1PongTensor, extraInfo);

            this->CopyInAttenMask(extraInfo, -1);
            pipe_barrier(PIPE_V);

            Muls(stage1PingTensor, stage1PongTensor, static_cast<T>(this->tilingData->inputParams.scaleValue),
                 extraInfo.vecS1BaseSize * extraInfo.s2AlignSize);
            SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
            if (this->attenMaskCompressMode !=
                    static_cast<uint8_t>(AttenMaskCompressMode::PREFIX_MODE)) {
                this->ComputeAttenMask(extraInfo, stage1PingTensor, 0);
            }
            if (this->attenMaskCompressMode ==
                static_cast<uint8_t>(AttenMaskCompressMode::BAND_MODE)) {
                event_t eventIdVToMte2Tmp = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
                SetFlag<HardEvent::V_MTE2>(eventIdVToMte2Tmp);
                WaitFlag<HardEvent::V_MTE2>(eventIdVToMte2Tmp);
                this->CopyInAttenMask(extraInfo, this->attenMaskOffsetPre);
                event_t eventIdMte2ToVTmp = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
                SetFlag<HardEvent::MTE2_V>(eventIdMte2ToVTmp);
                WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToVTmp);
                this->ComputeAttenMask(extraInfo, stage1PingTensor, 1);
            }
            if (this->attenMaskCompressMode ==
                static_cast<uint8_t>(AttenMaskCompressMode::PREFIX_MODE)) {
                event_t eventIdMte3ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
                SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
                WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
                this->CopyInAttenMask(extraInfo, this->attenMaskOffsetPre);
                int32_t maskNum = extraInfo.vecS1BaseSize * extraInfo.s2AlignBlockSize / 2; // halved: the data is processed as uint16
                event_t eventIdMte2ToVTmp = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
                SetFlag<HardEvent::MTE2_V>(eventIdMte2ToVTmp);
                WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToVTmp);
                LocalTensor<uint8_t> attenMaskCasualUb = this->maskTBufPing.template Get<uint8_t>();
                LocalTensor<uint8_t> attenMaskPrefixUb = this->pseTBuf.template Get<uint8_t>();
                auto attenMaskCasualTmp = attenMaskCasualUb.ReinterpretCast<uint16_t>();
                auto attenMaskPrefixUbTmp = attenMaskPrefixUb.ReinterpretCast<uint16_t>();
                And(attenMaskCasualTmp, attenMaskCasualTmp, attenMaskPrefixUbTmp, maskNum);
                pipe_barrier(PIPE_V);
                attenMaskCasualUb = attenMaskCasualTmp.ReinterpretCast<uint8_t>();
                this->ComputeAttenMask(extraInfo, stage1PingTensor, 0);
            }
            if (loopIdxNew < this->biN2G * this->s1OuterSize - 1) {
                SetFlag<HardEvent::V_MTE2>(eventIdVToMte2A);
            }

            if (loopIdxNew > 0) {
                WaitFlag<HardEvent::V_MTE2>(eventIdVToMte2B);
            }

            this->CopyInDropMask(extraInfo);
            this->SoftMaxCompute(extraInfo, stage1PingTensor, loopIdx);
            this->ComputeDropMask(extraInfo, stage1PingTensor, eventIdMte2ToV);

            if (loopIdxNew < this->biN2G * this->s1OuterSize - 1) {
                SetFlag<HardEvent::V_MTE2>(eventIdVToMte2B);
            }
            pipe_barrier(PIPE_V);

            if (loopIdxNew > 0) {
                WaitFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
            }

            if constexpr (!IsSameType<T, INPUT_T>::value) {
                LocalTensor<INPUT_T> stage1CastTensor = this->vecOut.template Get<INPUT_T>();
                Cast(stage1CastTensor, stage1PingTensor, RoundMode::CAST_ROUND, extraInfo.vecS1BaseSize *
                     extraInfo.s2AlignSize);
                SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
                WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);
                DataCopyParams dataCopyParams;
                DataCopyPadParams dataCopyPadParams;
                dataCopyParams.blockCount = extraInfo.vecS1BaseSize;
                dataCopyParams.dstStride = 0;
                dataCopyParams.srcStride = 0;
                dataCopyParams.blockLen = this->s2Size * sizeof(INPUT_T);
                if (extraInfo.taskId % 2 == 0) {
                    DataCopyPad(this->stage1ResPing[biN2GoIdx * this->s1Size * this->s2Size  + loopIdx *
                                this->s1BaseSize * this->s2Size], stage1CastTensor, dataCopyParams);
                } else {
                    DataCopyPad(this->stage1ResPong[biN2GoIdx * this->s1Size * this->s2Size + loopIdx *
                                this->s1BaseSize * this->s2Size], stage1CastTensor, dataCopyParams);
                }
            }
            if (loopIdxNew < this->biN2G * this->s1OuterSize - 1) {
                SetFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
            }
        }
    }
    GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(eventIdMte2ToV);
    GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE2>(eventIdVToMte2A);
    GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE2>(eventIdVToMte2B);
    return;
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::ProcessVec2(SplitBExtraInfo &extraInfo)
{
    event_t eventIdMte3ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
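
    // Stage-2 vector pass: for each (b_i, n2, g) batch and s1 block, load the bmm2 result
    // and the per-row softmax sum from GM, divide each row by its sum (the flash-attention
    // normalization), then cast and copy the attention output to its final layout.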
    for (uint32_t biN2GoIdx = 0; biN2GoIdx < this->biN2G; ++biN2GoIdx) {
        extraInfo.biN2GoIdx = biN2GoIdx;
        // fetch the cached bmm2 result
        LocalTensor<T> bmm2ResPingUb = this->stage1PingBuf.template Get<T>();
        LocalTensor<T> bmm2ResPongUb = this->stage1PongBuf.template Get<T>();
        LocalTensor<T> softmaxSumPingUb = this->pseTBuf.template Get<float>();
        LocalTensor<T> softmaxSumPongUb = this->vecOut.template Get<float>();

        // pseTBuf and vecOut serve both as the UB for the softmax input and for the final attention out output
        LocalTensor<INPUT_T> attentionOutPingUb = this->pseTBuf.template Get<INPUT_T>();
        LocalTensor<INPUT_T> attentionOutPongUb = this->vecOut.template Get<INPUT_T>();

        int64_t bOffset = extraInfo.boIdx * this->biN2GS1 * fp32BaseSize;
        int64_t biN2GOffset = extraInfo.biN2GoIdx * this->s1Size * this->dSize;
        int64_t biN2GSumOffset = extraInfo.biN2GoIdx * this->s1Size * fp32BaseSize;
        int64_t sumGmOffset = bOffset + biN2GSumOffset;
        for (int64_t s1oIdx = 0; s1oIdx < extraInfo.s1Vec2OuterSize; ++s1oIdx) {
            int64_t sumGmRealOffset = 0;
            extraInfo.s1Vec2BaseSize = this->s1Vec2BaseSize;
            if (s1oIdx == extraInfo.s1Vec2OuterSize - 1) {
                extraInfo.s1Vec2BaseSize = extraInfo.s1Vec2BaseTailSize;
            }
            int64_t sumInnerSize = extraInfo.s1Vec2BaseSize * fp32BaseSize;
            int64_t mm2ResCalcSize = extraInfo.s1Vec2BaseSize * this->dSize;
            int64_t mm2ResOffset = biN2GOffset + s1oIdx * this->s1Vec2BaseSize * this->dSize;
            SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
            WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
            sumGmRealOffset = sumGmOffset + s1oIdx * this->s1Vec2BaseSize * fp32BaseSize;
            if (extraInfo.taskId % 2 == 0) {
                DataCopy(bmm2ResPingUb, this->mm2ResPing[mm2ResOffset], mm2ResCalcSize);
                DataCopy(softmaxSumPingUb, this->softmaxSumGm[sumGmRealOffset], sumInnerSize);

                Bmm2ResultDiv(extraInfo, s1oIdx, bmm2ResPingUb, softmaxSumPingUb, extraInfo.s1Vec2BaseSize);
                Bmm2DataCopyOut(extraInfo, s1oIdx, bmm2ResPingUb, attentionOutPingUb);
            } else {
                DataCopy(bmm2ResPongUb, this->mm2ResPong[mm2ResOffset], mm2ResCalcSize);
                DataCopy(softmaxSumPongUb, this->softmaxSumGm[sumGmRealOffset], sumInnerSize);

                Bmm2ResultDiv(extraInfo, s1oIdx, bmm2ResPongUb, softmaxSumPongUb, extraInfo.s1Vec2BaseSize);
                Bmm2DataCopyOut(extraInfo, s1oIdx, bmm2ResPongUb, attentionOutPongUb);
            }
        }
    }
    SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
    WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
    return;
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::Bmm2ResultDiv(SplitBExtraInfo &extraInfo, int64_t vec2S1Idx, LocalTensor<T> &bmm2Res,
    LocalTensor<T> &sumTensor, int64_t vec2S1BaseSize)
{
    BinaryRepeatParams repeatParams;
    repeatParams.src0BlkStride = 1;
    repeatParams.src0RepStride = this->dSize / blockSize;
    repeatParams.src1BlkStride = 0;
    repeatParams.src1RepStride = 1;
    repeatParams.dstRepStride = this->dSize / blockSize;
    int32_t s1OuterLoop = vec2S1BaseSize / repeatMaxTimes;
    int32_t s1OuterRemain = vec2S1BaseSize % repeatMaxTimes;
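    // Each dSize-wide row of bmm2Res is divided by that row's softmax sum. The sum tensor
    // keeps 8 fp32 copies per row (the softmax API layout), and src1BlkStride == 0 with
    // src1RepStride == 1 broadcasts one such block across the whole row.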
    pipe_barrier(PIPE_V);
    if constexpr (IsSameType<T, half>::value) {
        LocalTensor<T> sumCastTensor;
        if (extraInfo.taskId % 2 == 0) {
            sumCastTensor = this->pseTBuf.template Get<T>();
        } else {
            sumCastTensor = this->vecOut.template Get<T>();
        }
        Cast(sumCastTensor, sumTensor, RoundMode::CAST_ROUND, sumTensor.GetSize());
        pipe_barrier(PIPE_V);
        for (int32_t i = 0; i < s1OuterLoop; i++) {
            int32_t innerLoop = this->dSize / repeatMaxSize;
            int32_t innerRemain = this->dSize % repeatMaxSize;
            int64_t s1OuterOffset = i * repeatMaxTimes * this->dSize;
            int64_t sumOffset = i * repeatMaxTimes * 8;
            for (int32_t j = 0; j < innerLoop; j++) {
                Div(bmm2Res[s1OuterOffset + j * repeatMaxSize], bmm2Res[s1OuterOffset + j * repeatMaxSize],
                    sumCastTensor[sumOffset], repeatMaxSize, repeatMaxTimes, repeatParams);
            }
            if (likely(innerRemain)) {
                Div(bmm2Res[s1OuterOffset + innerLoop * repeatMaxSize], bmm2Res[s1OuterOffset + innerLoop * repeatMaxSize],
                    sumCastTensor[sumOffset], innerRemain, repeatMaxTimes, repeatParams);
            }
        }
        if (likely(s1OuterRemain)) {
            int32_t innerLoop = this->dSize / repeatMaxSize;
            int32_t innerRemain = this->dSize % repeatMaxSize;
            int64_t s1OuterOffset = s1OuterLoop * repeatMaxTimes * this->dSize;
            int64_t sumOffset = s1OuterLoop * repeatMaxTimes * 8;
            for (int32_t j = 0; j < innerLoop; j++) {
                Div(bmm2Res[s1OuterOffset + j * repeatMaxSize], bmm2Res[s1OuterOffset + j * repeatMaxSize],
                    sumCastTensor[sumOffset], repeatMaxSize, s1OuterRemain, repeatParams);
            }
            if (likely(innerRemain)) {
                Div(bmm2Res[s1OuterOffset + innerLoop * repeatMaxSize], bmm2Res[s1OuterOffset + innerLoop * repeatMaxSize],
                    sumCastTensor[sumOffset], innerRemain, s1OuterRemain, repeatParams);
            }
        }
    } else {
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        for (int32_t i = 0; i < s1OuterLoop; i++) {
            int32_t innerLoop = this->dSize / repeatMaxSize;
            int32_t innerRemain = this->dSize % repeatMaxSize;
            int64_t s1OuterOffset = i * repeatMaxTimes * this->dSize;
            int64_t sumOffset = i * repeatMaxTimes * 8;
            for (int32_t j = 0; j < innerLoop; j++) {
                Div(bmm2Res[s1OuterOffset + j * repeatMaxSize], bmm2Res[s1OuterOffset + j * repeatMaxSize],
                    sumTensor[sumOffset], repeatMaxSize, repeatMaxTimes, repeatParams);
            }
            if (likely(innerRemain)) {
                Div(bmm2Res[s1OuterOffset + innerLoop * repeatMaxSize], bmm2Res[s1OuterOffset + innerLoop * repeatMaxSize],
                    sumTensor[sumOffset], innerRemain, repeatMaxTimes, repeatParams);
            }
        }
        if (likely(s1OuterRemain)) {
            int32_t innerLoop = this->dSize / repeatMaxSize;
            int32_t innerRemain = this->dSize % repeatMaxSize;
            int64_t s1OuterOffset = s1OuterLoop * repeatMaxTimes * this->dSize;
            int64_t sumOffset = s1OuterLoop * repeatMaxTimes * 8;
            for (int32_t j = 0; j < innerLoop; j++) {
                Div(bmm2Res[s1OuterOffset + j * repeatMaxSize], bmm2Res[s1OuterOffset + j * repeatMaxSize],
                    sumTensor[sumOffset], repeatMaxSize, s1OuterRemain, repeatParams);
            }
            if (likely(innerRemain)) {
                Div(bmm2Res[s1OuterOffset + innerLoop * repeatMaxSize], bmm2Res[s1OuterOffset + innerLoop * repeatMaxSize],
                    sumTensor[sumOffset], innerRemain, s1OuterRemain, repeatParams);
            }
        }
    }
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop,  typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::Bmm2DataCopyOut(
    SplitBExtraInfo &extraInfo, int64_t vec2S1Idx, LocalTensor<T> &bmm2Res, LocalTensor<INPUT_T> &attentionOut)
{
    uint32_t calcSize = bmm2Res.GetSize();
    pipe_barrier(PIPE_V);

    if constexpr (!IsSameType<INPUT_T, T>::value) {
        Cast(attentionOut, bmm2Res, RoundMode::CAST_ROUND, calcSize);
    }

    event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
    SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
    WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);

    DataCopyParams dataCopyParams;
    dataCopyParams.blockLen = this->dSize * sizeof(INPUT_T);
    dataCopyParams.srcStride = 0;
    int64_t dstStride = 0;
    int64_t attenOutOffset = this->dSize;
    int64_t biOffset = (extraInfo.biN2GoIdx / this->n2G) * this->n2GS1D;
    int64_t n2GOffset = this->s1Size * this->dSize;
    if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BSH) {
        attenOutOffset = this->n2GD;
        n2GOffset = this->dSize;
        dstStride = (this->tilingData->inputParams.n2Size * this->tilingData->inputParams.gSize - 1) * this->dSize *
                     sizeof(INPUT_T);
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_SBH) {
        attenOutOffset = this->bN2GD;
        n2GOffset = this->dSize;
        biOffset = (extraInfo.biN2GoIdx / this->n2G) * this->n2GD;
        dstStride = (this->tilingData->inputParams.bSize * this->tilingData->inputParams.n2Size *
                     this->tilingData->inputParams.gSize - 1) * this->dSize * sizeof(INPUT_T);
    }
    int64_t attOutBaseOffset = extraInfo.qCoreOffset + biOffset + (extraInfo.biN2GoIdx % this->n2G) *
                               n2GOffset + vec2S1Idx * this->s1Vec2BaseSize * attenOutOffset;
    // dataCopyParams.dstStride is declared as uint16_t; 65535 is its maximum value
    if (likely(dstStride <= 65535)) {
        dataCopyParams.blockCount = extraInfo.s1Vec2BaseSize;
        dataCopyParams.dstStride = static_cast<uint16_t>(dstStride);
        DataCopyPad(this->attentionOutGm[attOutBaseOffset], attentionOut, dataCopyParams);
    } else {
        dataCopyParams.blockCount = 1;
        dataCopyParams.dstStride = 0;
        int64_t datacopyOffset = this->dSize;
        if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BSH) {
            datacopyOffset = this->n2GD;
        } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_SBH) {
            datacopyOffset = this->bN2G * this->dSize;
        }

        for (uint32_t i = 0; i < extraInfo.s1Vec2BaseSize; i++) {
            DataCopyPad(this->attentionOutGm[attOutBaseOffset + i * datacopyOffset],
                      attentionOut[i * this->dSize], dataCopyParams);
        }
    }
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::DataCopyIn(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,  \
    int32_t s1Size, int32_t alignS2Size)
{
    int32_t dtypeSize = sizeof(INPUT_T);
    uint32_t shapeArray[] = {static_cast<uint32_t>(s1Size), static_cast<uint32_t>(alignS2Size)};
    dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
    dstTensor.SetSize(s1Size * alignS2Size);
    DataCopyParams dataCopyParams;
    dataCopyParams.blockCount = s1Size;
    dataCopyParams.blockLen = CeilDiv(this->s2Size * dtypeSize, blockBytes);    // unit: 32-byte blocks
    dataCopyParams.dstStride = 0;
    if (this->s2Size * dtypeSize % blockBytes == 0) {
        dataCopyParams.srcStride = 0;
        DataCopy(dstTensor, srcTensor[offset], dataCopyParams);
    } else {
        dataCopyParams.blockLen = this->s2Size * dtypeSize;                                 // unit: bytes
        dataCopyParams.srcStride = 0;
        dataCopyParams.dstStride = 0;
        DataCopyPadParams dataCopyPadParams;
        dataCopyPadParams.rightPadding = alignS2Size - s2Size;
        DataCopyPad(dstTensor, srcTensor[offset], dataCopyParams, dataCopyPadParams);
    }
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::CopyInAndCastPse(LocalTensor<T> &castTensor, SplitBExtraInfo &extraInfo, int64_t loopIdx)
{
    if constexpr (hasPse == true) {
        int64_t offset = this->ComputePseOffset(extraInfo);
        LocalTensor<INPUT_T> pseUb = this->pseTBuf.template Get<INPUT_T>();

        int32_t s1Size = this->tilingData->inputParams.pseShapeType == pse1S2 ? 1 : extraInfo.vecS1BaseSize;
        this->DataCopyIn(pseUb, this->pseGm, offset, s1Size, extraInfo.s2AlignSize);

        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        castTensor.SetSize(extraInfo.vecS1BaseSize * extraInfo.s2AlignSize);
        if constexpr (!IsSameType<INPUT_T, T>::value) {
            Cast(castTensor, pseUb, RoundMode::CAST_NONE, s1Size * extraInfo.s2AlignSize);
        } else {
            DataCopy(castTensor, pseUb, s1Size * extraInfo.s2AlignSize);
        }
        return;
    }
}

template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline int64_t FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    layout>::ComputePseOffset(SplitBExtraInfo &extraInfo)
{
    if constexpr (hasPse == true) {
        int64_t bOffset = 0;
        int64_t s1Offset = 0;
        int64_t biN2GOffset = 0;
        if (this->tilingData->inputParams.pseShapeType == pseS1S2) {
            // b, n2, g, s1, s2
            bOffset = this->pseBSize == 1 ? 0 : extraInfo.boIdx * this->biN2GS1S2;
            s1Offset = extraInfo.s1oIdx * this->s1BaseSize * this->s2Size;
            biN2GOffset = extraInfo.biN2GoIdx * this->s1Size * this->s2Size;
        } else if (this->tilingData->inputParams.pseShapeType == pse1S2) {
            // b, n2, g, 1, s2
            bOffset = this->pseBSize == 1 ? 0 : extraInfo.boIdx * this->biN2GS2;
            biN2GOffset = extraInfo.biN2GoIdx * this->s2Size;
        }
        return bOffset + biN2GOffset + s1Offset;
    }
}


template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::CopyInAttenMask(SplitBExtraInfo& extraInfo, int64_t maskOffset)
{
    if constexpr (hasAtten == true) {
        LocalTensor<uint8_t> attenMaskUb = this->maskTBufPing.template Get<uint8_t>();
        if (this->attenMaskCompressMode == static_cast<uint8_t>(AttenMaskCompressMode::PREFIX_MODE) && maskOffset != -1) {
            pipe_barrier(PIPE_V);
            attenMaskUb = this->pseTBuf.template Get<uint8_t>();
        }
        if (maskOffset == -1) {
            maskOffset = this->ComputeAttenMaskOffset(extraInfo);
        }
        this->BoolCopyIn(attenMaskUb, this->attenMaskGmInt, maskOffset, extraInfo, this->s2Size,
                         this->tilingData->inputParams.attenMaskS2Size);
        return;
    }
}

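// Compute the source offset of the attention-mask tile for the current compress mode (no compress,
// left-up/right-down causal, band, prefix); band and prefix modes also prepare attenMaskOffsetPre
// for the second mask copy.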
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline int64_t FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::ComputeAttenMaskOffset(SplitBExtraInfo& extraInfo)
{
    if constexpr (hasAtten == true) {
        int64_t deltaCausalOrNext = 0;
        int64_t s1Offset = extraInfo.s1oIdx * this->s1BaseSize;
        int64_t s2Offset = 0;
        int64_t deltaN = (int64_t)this->s1Size - (int64_t)this->s2Size;
        int64_t deltaPre = 0;
        if (this->attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::NO_COMPRESS_MODE)) {
            return this->ComputeOffsetForNoCompress(extraInfo);
        } else if (this->attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::LEFT_UP_CAUSAL_MODE)) {
            deltaCausalOrNext = s1Offset - s2Offset;
        } else if (this->attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::RIGHT_DOWN_CAUSAL_MODE)) {
            deltaCausalOrNext = s1Offset - s2Offset - deltaN;
        } else if (this->attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::BAND_MODE)) {
            deltaPre = s1Offset - s2Offset - this->tilingData->inputParams.preTokens - 1;
            this->attenMaskOffsetPre = ComputeOffsetForCausal(deltaPre, this->s1BaseSize,
                                       this->s2Size, this->tilingData->inputParams.attenMaskS2Size);
            deltaCausalOrNext = s1Offset - s2Offset + this->tilingData->inputParams.nextTokens;
        } else if (this->attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::PREFIX_MODE)) {
            deltaCausalOrNext = s1Offset - s2Offset - deltaN;
            deltaPre = (this->s1Size + this->prefixNList[extraInfo.boIdx] > this->s2Size) ?
                        (this->prefixNList[extraInfo.boIdx] + 1 - s2Offset) : 0;
            this->attenMaskOffsetPre = ComputeOffsetForPrefixRectangle(deltaPre, this->s2Size,
                                        this->tilingData->inputParams.attenMaskS2Size);
        } else {
            return 0;
        }
        return ComputeOffsetForCausal(deltaCausalOrNext, this->s1BaseSize, this->s2Size,
                                            this->tilingData->inputParams.attenMaskS2Size);
    }
    return 0;
}

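// Offset calculation for the non-compressed attention mask, covering the (B,N2,G,S1,S2),
// (B,1,1,S1,S2) and (1,1,1,S1,S2) shapes.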
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline int64_t FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::ComputeOffsetForNoCompress(SplitBExtraInfo& extraInfo)
{
    if constexpr (hasAtten == true) {
        int64_t bOffset = 0;
        int64_t s1Offset = 0;
        int64_t biN2GOffset = 0;
        if (this->tilingData->inputParams.attenMaskShapeType == 0) {            // 0: (B,N2,G,S1,S2)
            bOffset = extraInfo.boIdx * this->biN2GS1S2;
            s1Offset = extraInfo.s1oIdx * this->s1BaseSize * this->s2Size;
            biN2GOffset = extraInfo.biN2GoIdx * this->s1Size * this->s2Size;
        } else if (this->tilingData->inputParams.attenMaskShapeType == 1) {     // 1: (B,1,1,S1,S2)
            bOffset = extraInfo.boIdx * this->bBaseSize * this->s1S2;
            s1Offset = extraInfo.s1oIdx * this->s1BaseSize * this->s2Size;
            biN2GOffset = extraInfo.biN2GoIdx / (this->tilingData->inputParams.n2Size * this->tilingData->inputParams.gSize)
                          * this->s1Size * this->s2Size;
        } else if (this->tilingData->inputParams.attenMaskShapeType == 2) {     // 2: (1,1,1,s1,s2)
            s1Offset = extraInfo.s1oIdx * this->s1BaseSize * this->s2Size;
        }
        return bOffset + s1Offset + biN2GOffset;
    }
    return 0;
}

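// Copy a bool (uint8) mask tile from GM into UB: the aligned path uses DataCopy with a block stride,
// the unaligned path uses DataCopyPad and pads the tail of each row with 1.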
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::BoolCopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor, int64_t offset,  \
    SplitBExtraInfo& extraInfo, int32_t s2Size, int64_t totalS2Size)
{
    extraInfo.s2AlignBlockSize = CeilDiv(s2Size, blockBytes) * blockBytes;
    uint32_t shapeArray[] = {static_cast<uint32_t>(extraInfo.vecS1BaseSize),
                             static_cast<uint32_t>(extraInfo.s2AlignBlockSize)};
    dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
    dstTensor.SetSize(extraInfo.vecS1BaseSize * extraInfo.s2AlignBlockSize);
    DataCopyParams dataCopyParams;
    dataCopyParams.blockCount = extraInfo.vecS1BaseSize;
    dataCopyParams.dstStride = 0;
    if (s2Size % blockBytes == 0) {
        dataCopyParams.blockLen = extraInfo.s2AlignBlockSize / blockBytes;
        dataCopyParams.srcStride = (totalS2Size - extraInfo.s2AlignBlockSize) / blockBytes;
        DataCopy(dstTensor, srcTensor[offset], dataCopyParams);
    } else {
        dataCopyParams.blockLen = s2Size;
        dataCopyParams.srcStride = totalS2Size - s2Size;
        DataCopyPadParams dataCopyPadParams;
        dataCopyPadParams.isPad = true;
        dataCopyPadParams.rightPadding = extraInfo.s2AlignBlockSize - s2Size;
        dataCopyPadParams.paddingValue = 1;
        DataCopyPad(dstTensor, srcTensor[offset], dataCopyParams, dataCopyPadParams);
    }
}

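// Apply the attention mask to the bmm1 result with SelectWithBytesMask; maskType selects the operand
// order, i.e. which mask value causes the position to be filled with negativeFloatScalar.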
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::ComputeAttenMask(SplitBExtraInfo &extraInfo, LocalTensor<T> &bmm1ResUb, const uint8_t maskType)
{
    if constexpr (hasAtten == true) {
        LocalTensor<uint8_t> attenMaskUb = this->maskTBufPing.template Get<uint8_t>();
        uint32_t shapeArray[] = {static_cast<uint32_t>(extraInfo.vecS1BaseSize),
        static_cast<uint32_t>(extraInfo.s2AlignBlockSize)};
        attenMaskUb.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
        attenMaskUb.SetSize(extraInfo.vecS1BaseSize * extraInfo.s2AlignBlockSize);
        bmm1ResUb.SetSize(extraInfo.vecS1BaseSize * extraInfo.s2AlignSize);
        LocalTensor<uint8_t> apiTmpBuffer = this->commonTBuf.template Get<uint8_t>();
        SelectWithBytesMaskShapeInfo shapeInfo;
        shapeInfo.firstAxis = extraInfo.vecS1BaseSize;
        shapeInfo.srcLastAxis = extraInfo.s2AlignSize;
        shapeInfo.maskLastAxis = extraInfo.s2AlignBlockSize;

        if (maskType == 0) {
            SelectWithBytesMask(bmm1ResUb, bmm1ResUb, this->negativeFloatScalar, attenMaskUb, apiTmpBuffer, shapeInfo);
        } else {
            SelectWithBytesMask(bmm1ResUb, this->negativeFloatScalar, bmm1ResUb, attenMaskUb, apiTmpBuffer, shapeInfo);
        }
        return;
    }
}

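// Load the dropout-mask tile into UB: the unaligned path reads the byte mask expanded in
// dropoutWorkspaceGm with right padding, the aligned path reads the packed bit mask from dropMaskGm.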
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    layout>::CopyInDropMask(SplitBExtraInfo& extraInfo)
{
    if constexpr (hasDrop == true) {
        // Dropout has to be applied to the softmax result, so load the drop mask here
        int64_t dropMaskOffset = this->ComputeDropOffset(extraInfo);
        LocalTensor<uint8_t> dropMaskUb = this->maskTBufPong.template Get<uint8_t>();
        // Only the aligned case is supported here; the unaligned case uses BoolCopyIn-style padding
        uint32_t shapeArray[] = {static_cast<uint32_t>(extraInfo.vecS1BaseSize),
                                 static_cast<uint32_t>(extraInfo.s2AlignBlockSize)};
        dropMaskUb.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
        dropMaskUb.SetSize(extraInfo.vecS1BaseSize * extraInfo.s2AlignBlockSize);

        if (this->dropMaskUnAligned) {
            DataCopyParams dataCopyParams;
            dataCopyParams.blockCount = extraInfo.vecS1BaseSize;
            dataCopyParams.blockLen = this->s2Size;
            dataCopyParams.srcStride = 0;
            dataCopyParams.dstStride = 0;

            DataCopyPadParams dataCopyPadParams;
            dataCopyPadParams.isPad = true;
            dataCopyPadParams.leftPadding = 0;
            dataCopyPadParams.rightPadding =
                (this->s2Size + blockBytes - 1) / blockBytes * blockBytes - this->s2Size;
            dataCopyPadParams.paddingValue = 0;
            DataCopyPad(dropMaskUb, dropoutWorkspaceGm[dropMaskOffset], dataCopyParams, dataCopyPadParams);
        } else {
            DataCopyPad(dropMaskUb, this->dropMaskGm[dropMaskOffset / 8],
                        {static_cast<uint16_t>(extraInfo.vecS1BaseSize), static_cast<uint16_t>(this->s2Size / 8), 0, 0},
                        {false, 0, 0, 0});
        }
        return;
    }
}

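// Element offset of the current tile inside the (B, N2, G, S1, S2) dropout mask.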
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline int64_t FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
  layout>::ComputeDropOffset(SplitBExtraInfo &extraInfo)
{
    if constexpr (hasDrop == true) {
        int64_t bOffset = extraInfo.boIdx * this->biN2GS1S2;
        int64_t s1Offset = extraInfo.s1oIdx * this->s1BaseSize * this->s2Size;
        int64_t biN2GOffset = extraInfo.biN2GoIdx * this->s1Size * this->s2Size;
        return bOffset + s1Offset + biN2GOffset;
    }
    return 0;
}

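// Run flash-softmax (SoftmaxFlashV2) on the bmm1 result, keeping the per-row max/sum/exp statistics in
// ping-pong UB buffers, and batch-copy the accumulated max/sum to GM every softmaxCopyOutLimit loops or
// on the last S1 loop. In high-precision invalid-line mode, fully masked rows are adjusted first.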
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::SoftMaxCompute(SplitBExtraInfo &extraInfo, LocalTensor<T> &srcTensor, int64_t loopIdx)
{
    uint32_t bmm1ResUbShape[] = {static_cast<uint32_t>(extraInfo.vecS1BaseSize),
                                 static_cast<uint32_t>(extraInfo.s2AlignSize)};
    uint32_t bmm1ResUbOriShape[] = {static_cast<uint32_t>(extraInfo.vecS1BaseSize),
                                    static_cast<uint32_t>(this->s2Size)};
    srcTensor.SetShapeInfo(ShapeInfo(2, bmm1ResUbShape, 2, bmm1ResUbOriShape, DataFormat::ND));

    uint32_t maxSumShape[] = {static_cast<uint32_t>(extraInfo.vecS1BaseSize),
                              static_cast<uint32_t>(fp32BaseSize)};
    LocalTensor<T> sumUb;
    LocalTensor<T> maxUb;
    int64_t sumOffset = (loopIdx % extraInfo.softmaxCopyOutLimit) * this->s1BaseSize * fp32BaseSize;
    if (extraInfo.boIdx % 2 == 0) {
        sumUb = this->softmaxSumPingBuf.template Get<T>()[sumOffset];
        maxUb = this->softmaxMaxPingBuf.template Get<T>()[sumOffset];
    } else {
        sumUb = this->softmaxSumPongBuf.template Get<T>()[sumOffset];
        maxUb = this->softmaxMaxPongBuf.template Get<T>()[sumOffset];
    }

    sumUb.SetShapeInfo(ShapeInfo(2, maxSumShape, DataFormat::ND));
    maxUb.SetShapeInfo(ShapeInfo(2, maxSumShape, DataFormat::ND));

    uint32_t expShape[] = {static_cast<uint32_t>(extraInfo.vecS1BaseSize),
                           static_cast<uint32_t>(fp32BaseSize)};
    LocalTensor<T> expUb;
    int64_t expOffset = sumOffset;
    expUb = this->softmaxExpBuf.template Get<T>()[expOffset];
    expUb.SetShapeInfo(ShapeInfo(2, expShape, DataFormat::ND));

    LocalTensor<uint8_t> apiTmpBuffer = this->commonTBuf.template Get<uint8_t>();
    pipe_barrier(PIPE_V);
    SoftMaxTiling softmaxFlashTilingData;
    if (IsBasicBlockInSoftMax(extraInfo.vecS1BaseSize, this->s2Size)) {
        SoftmaxFlashV2<T, false, true, true>(srcTensor, sumUb, maxUb, srcTensor, expUb, sumUb, maxUb, apiTmpBuffer,
                                             softmaxFlashTilingData);
    } else {
        SoftmaxFlashV2<T, false, true, false>(srcTensor, sumUb, maxUb, srcTensor, expUb, sumUb, maxUb, apiTmpBuffer,
                                              softmaxFlashTilingData);
    }

    if constexpr (implMode == ImplModeEnum::AA_INVALID_LINE_HIGH_PRECISION) {
        SoftMaxShapeInfo softmaxShapeInfo{static_cast<uint32_t>(extraInfo.vecS1BaseSize),
                                          static_cast<uint32_t>(extraInfo.s2AlignSize),
                                          static_cast<uint32_t>(extraInfo.vecS1BaseSize),
                                          static_cast<uint32_t>(this->s2Size)};
        AdjustSoftMaxRes<T, T>(srcTensor, maxUb, this->negativeIntScalar, 0.0, softmaxShapeInfo);
        if (loopIdx == this->s1OuterSize - 1 || (loopIdx + 1) % extraInfo.softmaxCopyOutLimit == 0) {
            LocalTensor<T> maxTensor;
            LocalTensor<T> sumTensor;
            if (extraInfo.boIdx % 2 == 0) {
                maxTensor = this->softmaxMaxPingBuf.template Get<T>();
                sumTensor = this->softmaxSumPingBuf.template Get<T>();
            } else {
                maxTensor = this->softmaxMaxPongBuf.template Get<T>();
                sumTensor = this->softmaxSumPongBuf.template Get<T>();
            }
            uint32_t currS1Size = static_cast<uint32_t>(Min(extraInfo.softmaxCopyOutSize * fp32BaseSize,
                                  this->s1Size * fp32BaseSize - extraInfo.softmaxOutOffset));
            currS1Size = CeilDiv(currS1Size, fp32BaseSize);
            SoftMaxShapeInfo softmaxShapeInfo{currS1Size, static_cast<uint32_t>(fp32BaseSize),
                                              currS1Size, static_cast<uint32_t>(fp32BaseSize)};
            AdjustSoftMaxRes<T, T>(sumTensor, maxTensor, this->negativeIntScalar, this->positiveFloatScalar,
                                   softmaxShapeInfo);
        }
    }

    if (loopIdx == this->s1OuterSize - 1 || (loopIdx + 1) % extraInfo.softmaxCopyOutLimit == 0) {
        int64_t bOffset = extraInfo.boIdx * this->biN2GS1 * fp32BaseSize;
        int64_t biN2GOffset = extraInfo.biN2GoIdx * this->s1Size * fp32BaseSize;
        int64_t vS1Offset = extraInfo.softmaxOutOffset;

        int64_t gmOffset = bOffset + biN2GOffset + vS1Offset;
        int64_t calculateSize = Min(extraInfo.softmaxCopyOutSize * fp32BaseSize,
                                        this->s1Size * fp32BaseSize - extraInfo.softmaxOutOffset);
        LocalTensor<float> sumUbStart;
        LocalTensor<float> maxUbStart;
        if (extraInfo.boIdx % 2 == 0) {
            sumUbStart = this->softmaxSumPingBuf.template Get<float>()[0];
            maxUbStart = this->softmaxMaxPingBuf.template Get<float>()[0];
        } else {
            sumUbStart = this->softmaxSumPongBuf.template Get<float>()[0];
            maxUbStart = this->softmaxMaxPongBuf.template Get<float>()[0];
        }

        event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
        SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
        WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);
        DataCopy(this->softmaxSumGm[gmOffset], sumUbStart, calculateSize);
        DataCopy(this->softmaxMaxGm[gmOffset], maxUbStart, calculateSize);
        event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
        SetFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
        WaitFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
        extraInfo.softmaxOutOffset += calculateSize;
    }
}

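// Apply dropout to the softmax result with DropOut, using the byte mask (unaligned path) or the packed
// bit mask (aligned path) loaded by CopyInDropMask, scaled by keepProb.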
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    LayoutMode layout>
__aicore__ inline void FlashAttentionScoreSplitBBn2gs1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,  \
    layout>::ComputeDropMask(SplitBExtraInfo &extraInfo, LocalTensor<T> &bmm1ResUb, event_t eventIdMte2ToV)
{
    if constexpr (hasDrop == true) {
        LocalTensor<uint8_t> apiTmpBuffer = this->commonTBuf.template Get<uint8_t>();

        // Dropout has to be applied to the softmax result
        LocalTensor<uint8_t> dropMaskUb = this->maskTBufPong.template Get<uint8_t>();
        pipe_barrier(PIPE_V);

        DropOutShapeInfo dropOutShapeInfo;
        dropOutShapeInfo.firstAxis = extraInfo.vecS1BaseSize;
        dropOutShapeInfo.srcLastAxis = extraInfo.s2AlignSize;
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        if (this->dropMaskUnAligned) {
            dropOutShapeInfo.maskLastAxis = (this->s2Size + blockBytes - 1) / blockBytes * blockBytes;
            DropOut<T, true>(bmm1ResUb, bmm1ResUb, dropMaskUb, apiTmpBuffer, this->tilingData->inputParams.keepProb,
                             dropOutShapeInfo);
        } else {
            dropOutShapeInfo.maskLastAxis = (this->s2Size / 8 + blockBytes - 1) / blockBytes * blockBytes;
            DropOut<T, false, 4>(bmm1ResUb, bmm1ResUb, dropMaskUb, apiTmpBuffer, this->tilingData->inputParams.keepProb,
                                 dropOutShapeInfo);
        }
        return;
    }
}
#endif